{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 文本分类实例"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step1 导入相关包"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "d:\\Miniconda\\envs\\geo\\lib\\site-packages\\numpy\\_distributor_init.py:30: UserWarning: loaded more than 1 DLL from .libs:\n",
      "d:\\Miniconda\\envs\\geo\\lib\\site-packages\\numpy\\.libs\\libopenblas.FB5AE2TYXYH2IJRDKGDGQ3XBKLKTF43H.gfortran-win_amd64.dll\n",
      "d:\\Miniconda\\envs\\geo\\lib\\site-packages\\numpy\\.libs\\libopenblas64__v0.3.21-gcc_10_3_0.dll\n",
      "  warnings.warn(\"loaded more than 1 DLL from .libs:\"\n"
     ]
    }
   ],
   "source": [
    "from transformers import AutoTokenizer, AutoModelForSequenceClassification"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step2 加载数据"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>label</th>\n",
       "      <th>review</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>1</td>\n",
       "      <td>距离川沙公路较近,但是公交指示不对,如果是\"蔡陆线\"的话,会非常麻烦.建议用别的路线.房间较...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1</td>\n",
       "      <td>商务大床房，房间很大，床有2M宽，整体感觉经济实惠不错!</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>1</td>\n",
       "      <td>早餐太差，无论去多少人，那边也不加食品的。酒店应该重视一下这个问题了。房间本身很好。</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>1</td>\n",
       "      <td>宾馆在小街道上，不大好找，但还好北京热心同胞很多~宾馆设施跟介绍的差不多，房间很小，确实挺小...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>1</td>\n",
       "      <td>CBD中心,周围没什么店铺,说5星有点勉强.不知道为什么卫生间没有电吹风</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7761</th>\n",
       "      <td>0</td>\n",
       "      <td>尼斯酒店的几大特点：噪音大、环境差、配置低、服务效率低。如：1、隔壁歌厅的声音闹至午夜3点许...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7762</th>\n",
       "      <td>0</td>\n",
       "      <td>盐城来了很多次，第一次住盐阜宾馆，我的确很失望整个墙壁黑咕隆咚的，好像被烟熏过一样家具非常的...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7763</th>\n",
       "      <td>0</td>\n",
       "      <td>看照片觉得还挺不错的，又是4星级的，但入住以后除了后悔没有别的，房间挺大但空空的，早餐是有但...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7764</th>\n",
       "      <td>0</td>\n",
       "      <td>我们去盐城的时候那里的最低气温只有4度，晚上冷得要死，居然还不开空调，投诉到酒店客房部，得到...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7765</th>\n",
       "      <td>0</td>\n",
       "      <td>说实在的我很失望，之前看了其他人的点评后觉得还可以才去的，结果让我们大跌眼镜。我想这家酒店以...</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>7766 rows × 2 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "      label                                             review\n",
       "0         1  距离川沙公路较近,但是公交指示不对,如果是\"蔡陆线\"的话,会非常麻烦.建议用别的路线.房间较...\n",
       "1         1                       商务大床房，房间很大，床有2M宽，整体感觉经济实惠不错!\n",
       "2         1         早餐太差，无论去多少人，那边也不加食品的。酒店应该重视一下这个问题了。房间本身很好。\n",
       "3         1  宾馆在小街道上，不大好找，但还好北京热心同胞很多~宾馆设施跟介绍的差不多，房间很小，确实挺小...\n",
       "4         1               CBD中心,周围没什么店铺,说5星有点勉强.不知道为什么卫生间没有电吹风\n",
       "...     ...                                                ...\n",
       "7761      0  尼斯酒店的几大特点：噪音大、环境差、配置低、服务效率低。如：1、隔壁歌厅的声音闹至午夜3点许...\n",
       "7762      0  盐城来了很多次，第一次住盐阜宾馆，我的确很失望整个墙壁黑咕隆咚的，好像被烟熏过一样家具非常的...\n",
       "7763      0  看照片觉得还挺不错的，又是4星级的，但入住以后除了后悔没有别的，房间挺大但空空的，早餐是有但...\n",
       "7764      0  我们去盐城的时候那里的最低气温只有4度，晚上冷得要死，居然还不开空调，投诉到酒店客房部，得到...\n",
       "7765      0  说实在的我很失望，之前看了其他人的点评后觉得还可以才去的，结果让我们大跌眼镜。我想这家酒店以...\n",
       "\n",
       "[7766 rows x 2 columns]"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import pandas as pd\n",
    "pretrain_model_dir = \"/data/models/huggingface/rbt3\"\n",
    "save_dir = '/data/logs/data_parallel_2GPUS'\n",
    "data = pd.read_csv(\"./ChnSentiCorp_htl_all.csv\")\n",
    "data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>label</th>\n",
       "      <th>review</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>1</td>\n",
       "      <td>距离川沙公路较近,但是公交指示不对,如果是\"蔡陆线\"的话,会非常麻烦.建议用别的路线.房间较...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1</td>\n",
       "      <td>商务大床房，房间很大，床有2M宽，整体感觉经济实惠不错!</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>1</td>\n",
       "      <td>早餐太差，无论去多少人，那边也不加食品的。酒店应该重视一下这个问题了。房间本身很好。</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>1</td>\n",
       "      <td>宾馆在小街道上，不大好找，但还好北京热心同胞很多~宾馆设施跟介绍的差不多，房间很小，确实挺小...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>1</td>\n",
       "      <td>CBD中心,周围没什么店铺,说5星有点勉强.不知道为什么卫生间没有电吹风</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7761</th>\n",
       "      <td>0</td>\n",
       "      <td>尼斯酒店的几大特点：噪音大、环境差、配置低、服务效率低。如：1、隔壁歌厅的声音闹至午夜3点许...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7762</th>\n",
       "      <td>0</td>\n",
       "      <td>盐城来了很多次，第一次住盐阜宾馆，我的确很失望整个墙壁黑咕隆咚的，好像被烟熏过一样家具非常的...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7763</th>\n",
       "      <td>0</td>\n",
       "      <td>看照片觉得还挺不错的，又是4星级的，但入住以后除了后悔没有别的，房间挺大但空空的，早餐是有但...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7764</th>\n",
       "      <td>0</td>\n",
       "      <td>我们去盐城的时候那里的最低气温只有4度，晚上冷得要死，居然还不开空调，投诉到酒店客房部，得到...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7765</th>\n",
       "      <td>0</td>\n",
       "      <td>说实在的我很失望，之前看了其他人的点评后觉得还可以才去的，结果让我们大跌眼镜。我想这家酒店以...</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>7765 rows × 2 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "      label                                             review\n",
       "0         1  距离川沙公路较近,但是公交指示不对,如果是\"蔡陆线\"的话,会非常麻烦.建议用别的路线.房间较...\n",
       "1         1                       商务大床房，房间很大，床有2M宽，整体感觉经济实惠不错!\n",
       "2         1         早餐太差，无论去多少人，那边也不加食品的。酒店应该重视一下这个问题了。房间本身很好。\n",
       "3         1  宾馆在小街道上，不大好找，但还好北京热心同胞很多~宾馆设施跟介绍的差不多，房间很小，确实挺小...\n",
       "4         1               CBD中心,周围没什么店铺,说5星有点勉强.不知道为什么卫生间没有电吹风\n",
       "...     ...                                                ...\n",
       "7761      0  尼斯酒店的几大特点：噪音大、环境差、配置低、服务效率低。如：1、隔壁歌厅的声音闹至午夜3点许...\n",
       "7762      0  盐城来了很多次，第一次住盐阜宾馆，我的确很失望整个墙壁黑咕隆咚的，好像被烟熏过一样家具非常的...\n",
       "7763      0  看照片觉得还挺不错的，又是4星级的，但入住以后除了后悔没有别的，房间挺大但空空的，早餐是有但...\n",
       "7764      0  我们去盐城的时候那里的最低气温只有4度，晚上冷得要死，居然还不开空调，投诉到酒店客房部，得到...\n",
       "7765      0  说实在的我很失望，之前看了其他人的点评后觉得还可以才去的，结果让我们大跌眼镜。我想这家酒店以...\n",
       "\n",
       "[7765 rows x 2 columns]"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data = data.dropna()\n",
    "data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step3 创建Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.utils.data import Dataset\n",
    "\n",
    "class MyDataset(Dataset):\n",
    "\n",
    "    def __init__(self) -> None:\n",
    "        super().__init__()\n",
    "        self.data = pd.read_csv(\"./ChnSentiCorp_htl_all.csv\")\n",
    "        self.data = self.data.dropna()\n",
    "\n",
    "    def __getitem__(self, index):\n",
    "        return self.data.iloc[index][\"review\"], self.data.iloc[index][\"label\"]\n",
    "    \n",
    "    def __len__(self):\n",
    "        return len(self.data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "('距离川沙公路较近,但是公交指示不对,如果是\"蔡陆线\"的话,会非常麻烦.建议用别的路线.房间较为简单.', 1)\n",
      "('商务大床房，房间很大，床有2M宽，整体感觉经济实惠不错!', 1)\n",
      "('早餐太差，无论去多少人，那边也不加食品的。酒店应该重视一下这个问题了。房间本身很好。', 1)\n",
      "('宾馆在小街道上，不大好找，但还好北京热心同胞很多~宾馆设施跟介绍的差不多，房间很小，确实挺小，但加上低价位因素，还是无超所值的；环境不错，就在小胡同内，安静整洁，暖气好足-_-||。。。呵还有一大优势就是从宾馆出发，步行不到十分钟就可以到梅兰芳故居等等，京味小胡同，北海距离好近呢。总之，不错。推荐给节约消费的自助游朋友~比较划算，附近特色小吃很多~', 1)\n",
      "('CBD中心,周围没什么店铺,说5星有点勉强.不知道为什么卫生间没有电吹风', 1)\n"
     ]
    }
   ],
   "source": [
    "dataset = MyDataset()\n",
    "for i in range(5):\n",
    "    print(dataset[i])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step4 划分数据集"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(6989, 776)"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from torch.utils.data import random_split\n",
    "\n",
    "\n",
    "trainset, validset = random_split(dataset, lengths=[0.9, 0.1])\n",
    "len(trainset), len(validset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "('由于房间里面的水是锈水，大堂副理免费升级到套间，由于我预定3天，讲好头2晚是478，第三晚是578。结果第三天直接告诉只能是678，经过与携程反映也没有效果，原因就是第三天有一批人来此开会，房间不够，故意把住客赶走。坚决再也不住这里了。态度十分差劲。', 0)\n",
      "('我去年的时候去过该酒店，那时候住的是无烟房，感觉还好，这回去住的是高级标准间，进去了房间就有一种怪味道要求前台换房，换过了以后还是多少有一点怪味。不过房间的墙壁特脏，后来想算了，就没在换，不过感觉今年的金鼎要比去年卫生环境差很多，装修也旧了！', 0)\n",
      "('在潮州可能也找不到更好的酒店了，感觉不如白玉兰。', 1)\n",
      "('坐落在香港的老城区，可以体验香港居民生活，门口交通很方便，如果时间不紧，坐叮当车很好呀！周围有很多小餐馆，早餐就在中远后面的南北嚼吃的，东西很不错。我们定的大床房，挺安静的，总体来说不错。前台结账没有银联！', 1)\n",
      "('刚去的时候有点难找,不过服务真是很不错哦~~而且金碧辉煌的~~!房间虽然不大不过我一个人住肯定够,很舒适而且窗外是海景和轮船呢~~嘿嘿..酒店就在海鲜街上,走不远还有皇后大道~交通呢..其实在香港交通都会很方便拉~酒店门口就有有轨电车~可以到地铁站也可以直接到中环旺角什么的..出门不远有个锅贴大王东西又好吃又便宜~~推荐！下次有机会还会考虑住这儿.!补充点评2008年7月24日：有空调!可以自己调温度！', 1)\n",
      "('环境不错，房间设施挺好，很有档次，下次还会考虑入住', 1)\n",
      "(\"八月初入住两天,感觉极好.幽雅的庭院式格局与众多的现代高楼酒店相比,给人一种苏州园林式的恬静与悠闲.服务快捷且有求必应.地处闹市却静如田园.喜欢酒吧的出门左手数步HARRIS'BAR晚间十点以后热闹非常,有菲籍BAND演奏.饮食与购物天堂仅一街之隔.通过携程订房有送雪碧及可乐各一.早餐质差,68元价无所值.酒店周围美食多多,可尽情享受.强力推荐!\", 1)\n",
      "('由于飞机晚点，半夜到酒店11点半左右。登记后无人提示酒店是否有无烟楼层，酒店不提供火柴，我半夜跑楼下去买的火机。房间没有吃的东西，水也没有，马桶也不舒服（座垫太小），半夜风大，窗户风吹的声音很大，（1115房间）。早餐20一份，自费，一碗皮蛋瘦肉粥、炒河粉、一杯牛奶，一个鸡蛋。吃饭地方偏贵，总体来说没有值得赞扬的地方。我住了2天半，最后半天上楼我出示房卡，清洁卫生的阿姨说要问楼下，一问就是3、4分钟。', 0)\n",
      "('2月10日刚刚住了,先定的2号楼的房间,进去一看吓了一跳!\"号称挂牌四星,简直不敢想,就连招待所都不如!天花板的三角板吊顶是用胶带纸粘住的!赶紧加假换房,住了一号楼,还看的过去,大概相当于二星级!还有一点不满的是,房间里的毛巾特别破旧,跟招待所的毛巾有的一比!总之,我是不会再住了.', 0)\n",
      "('比较朴实的一家酒店～～～还是上海衡山管理旗下，有些亲切感', 1)\n"
     ]
    }
   ],
   "source": [
    "for i in range(10):\n",
    "    print(trainset[i])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step5 创建Dataloader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(pretrain_model_dir)\n",
    "\n",
    "def collate_func(batch):\n",
    "    texts, labels = [], []\n",
    "    for item in batch:\n",
    "        texts.append(item[0])\n",
    "        labels.append(item[1])\n",
    "    inputs = tokenizer(texts, max_length=128, padding=\"max_length\", truncation=True, return_tensors=\"pt\")\n",
    "    inputs[\"labels\"] = torch.tensor(labels)\n",
    "    return inputs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.utils.data import DataLoader\n",
    "\n",
    "trainloader = DataLoader(trainset, batch_size=64, shuffle=True, collate_fn=collate_func)\n",
    "validloader = DataLoader(validset, batch_size=64, shuffle=False, collate_fn=collate_func)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'input_ids': tensor([[ 101, 6848,  749,  ...,    0,    0,    0],\n",
       "        [ 101, 2791, 7313,  ...,    0,    0,    0],\n",
       "        [ 101, 3146,  860,  ...,    0,    0,    0],\n",
       "        ...,\n",
       "        [ 101, 2456, 6379,  ...,    0,    0,    0],\n",
       "        [ 101, 2769, 3221,  ..., 6963, 3221,  102],\n",
       "        [ 101,  679, 7231,  ...,    0,    0,    0]]), 'token_type_ids': tensor([[0, 0, 0,  ..., 0, 0, 0],\n",
       "        [0, 0, 0,  ..., 0, 0, 0],\n",
       "        [0, 0, 0,  ..., 0, 0, 0],\n",
       "        ...,\n",
       "        [0, 0, 0,  ..., 0, 0, 0],\n",
       "        [0, 0, 0,  ..., 0, 0, 0],\n",
       "        [0, 0, 0,  ..., 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1,  ..., 0, 0, 0],\n",
       "        [1, 1, 1,  ..., 0, 0, 0],\n",
       "        [1, 1, 1,  ..., 0, 0, 0],\n",
       "        ...,\n",
       "        [1, 1, 1,  ..., 0, 0, 0],\n",
       "        [1, 1, 1,  ..., 1, 1, 1],\n",
       "        [1, 1, 1,  ..., 0, 0, 0]]), 'labels': tensor([0, 1, 1, 0, 1, 0, 0, 1, 0, 1, 0, 1, 1, 0, 1, 0, 0, 1, 0, 1, 1, 1, 0, 0,\n",
       "        1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1,\n",
       "        1, 0, 0, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1])}"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "next(enumerate(validloader))[1]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step6 创建模型及优化器"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at /data/models/huggingface/rbt3 and are newly initialized: ['classifier.bias', 'classifier.weight']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
     ]
    }
   ],
   "source": [
    "from torch.optim import Adam\n",
    "\n",
    "model = AutoModelForSequenceClassification.from_pretrained(pretrain_model_dir)\n",
    "\n",
    "if torch.cuda.is_available():\n",
    "    model = model.cuda()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "############### 重要 將模型加載大兩張卡上\n",
    "model = torch.nn.DataParallel(model,device_ids=[0,1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "optimizer = Adam(model.parameters(), lr=2e-5)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step7 训练与验证"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "\n",
    "def evaluate():\n",
    "    model.eval()\n",
    "    acc_num = 0\n",
    "    with torch.inference_mode():\n",
    "        for batch in validloader:\n",
    "            if torch.cuda.is_available():\n",
    "                batch = {k: v.cuda() for k, v in batch.items()}\n",
    "            output = model(**batch)\n",
    "            pred = torch.argmax(output.logits, dim=-1)\n",
    "            acc_num += (pred.long() == batch[\"labels\"].long()).float().sum()\n",
    "    return acc_num / len(validset)\n",
    "\n",
    "def train(epoch=3, log_step=100):\n",
    "    global_step = 0\n",
    "    for ep in range(epoch):\n",
    "        model.train()\n",
    "        start = time.time()\n",
    "        for batch in trainloader:\n",
    "            if torch.cuda.is_available():\n",
    "                batch = {k: v.cuda() for k, v in batch.items()}\n",
    "            optimizer.zero_grad()\n",
    "            output = model(**batch)\n",
    "            loss = output.loss.mean()\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "            if global_step % log_step == 0:\n",
    "                print(f\"ep: {ep}, global_step: {global_step}, loss: {loss.mean().item()}\")\n",
    "            global_step += 1\n",
    "        acc = evaluate()\n",
    "        print(f\"ep: {ep}, acc: {acc}, time: {time.time() - start}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step8 模型训练"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "d:\\Miniconda\\envs\\geo\\lib\\site-packages\\torch\\nn\\parallel\\_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
      "  warnings.warn('Was asked to gather along dimension 0, but all '\n",
      "d:\\Miniconda\\envs\\geo\\lib\\site-packages\\torch\\cuda\\nccl.py:15: UserWarning: PyTorch is not compiled with NCCL support\n",
      "  warnings.warn('PyTorch is not compiled with NCCL support')\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "ep: 0, global_step: 0, loss: 0.7683874368667603\n",
      "ep: 0, global_step: 100, loss: 0.3217390477657318\n",
      "ep: 0, acc: 0.8698453307151794, time: 25.773720264434814\n",
      "ep: 1, global_step: 200, loss: 0.16806672513484955\n",
      "ep: 1, acc: 0.8595360517501831, time: 18.950586080551147\n",
      "ep: 2, global_step: 300, loss: 0.31102514266967773\n",
      "ep: 2, acc: 0.89304119348526, time: 18.99100351333618\n"
     ]
    }
   ],
   "source": [
    "train()  # 1m26.9s"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step9 模型预测"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "输入：我觉得这家酒店不错，饭很好吃！\n",
      "模型预测结果:好评！\n"
     ]
    }
   ],
   "source": [
    "sen = \"我觉得这家酒店不错，饭很好吃！\"\n",
    "id2_label = {0: \"差评！\", 1: \"好评！\"}\n",
    "model.eval()\n",
    "with torch.inference_mode():\n",
    "    inputs = tokenizer(sen, return_tensors=\"pt\")\n",
    "    inputs = {k: v.cuda() for k, v in inputs.items()}\n",
    "    logits = model(**inputs).logits\n",
    "    pred = torch.argmax(logits, dim=-1)\n",
    "    print(f\"输入：{sen}\\n模型预测结果:{id2_label.get(pred.item())}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "ename": "AttributeError",
     "evalue": "'DataParallel' object has no attribute 'config'",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mAttributeError\u001b[0m                            Traceback (most recent call last)",
      "Cell \u001b[1;32mIn[17], line 3\u001b[0m\n\u001b[0;32m      1\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mtransformers\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m pipeline\n\u001b[1;32m----> 3\u001b[0m \u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mconfig\u001b[49m\u001b[38;5;241m.\u001b[39mid2label \u001b[38;5;241m=\u001b[39m id2_label\n\u001b[0;32m      4\u001b[0m pipe \u001b[38;5;241m=\u001b[39m pipeline(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtext-classification\u001b[39m\u001b[38;5;124m\"\u001b[39m, model\u001b[38;5;241m=\u001b[39mmodel, tokenizer\u001b[38;5;241m=\u001b[39mtokenizer, device\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m0\u001b[39m)\n",
      "File \u001b[1;32md:\\Miniconda\\envs\\geo\\lib\\site-packages\\torch\\nn\\modules\\module.py:1265\u001b[0m, in \u001b[0;36mModule.__getattr__\u001b[1;34m(self, name)\u001b[0m\n\u001b[0;32m   1263\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m name \u001b[38;5;129;01min\u001b[39;00m modules:\n\u001b[0;32m   1264\u001b[0m         \u001b[38;5;28;01mreturn\u001b[39;00m modules[name]\n\u001b[1;32m-> 1265\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mAttributeError\u001b[39;00m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;132;01m{}\u001b[39;00m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m object has no attribute \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;132;01m{}\u001b[39;00m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;241m.\u001b[39mformat(\n\u001b[0;32m   1266\u001b[0m     \u001b[38;5;28mtype\u001b[39m(\u001b[38;5;28mself\u001b[39m)\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m, name))\n",
      "\u001b[1;31mAttributeError\u001b[0m: 'DataParallel' object has no attribute 'config'"
     ]
    }
   ],
   "source": [
    "from transformers import pipeline\n",
    "\n",
    "model.config.id2label = id2_label\n",
    "pipe = pipeline(\"text-classification\", model=model, tokenizer=tokenizer, device=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pipe(sen)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "transformers",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
